{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extract LoRA from Fine-tune Stable Diffusion model and Deploy Model with Multiple Adapters\n",
    "\n",
    "In this tutorial, we will convert a fine-tune stable diffusion model to ONNX, and extract the LoRA adapters from the model. \n",
    "The resulting model can be deployed with multiple adapters for different tasks."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "Before running this tutorial, please ensure you already installed olive-ai. Please refer to the [installation guide](https://github.com/microsoft/Olive?tab=readme-ov-file#installation) for more information.\n",
    "\n",
    "### Install Dependencies\n",
    "We will optimize for `CUDAExecutionProvider` so `onnxruntime-gpu>=1.20` should also be installed allong with the other dependencies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# install required packages\n",
    "!pip install -r diffusers\n",
    "\n",
    "# install onnxruntime-gpu >=1.20\n",
    "!pip uninstall -y onnxruntime onnxruntime-gpu\n",
    "!pip install \"onnxruntime-gpu>=1.20\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Workflow\n",
    "\n",
    "Let's try lovely [wolf plushie LoRA](https://huggingface.co/lora-library/B-LoRA-wolf_plushie) and [pen sketch LoRA](https://huggingface.co/lora-library/B-LoRA-pen_sketch) in this example.\n",
    "\n",
    "Olive provides command line tools to run the export and extract adapters workflow. This workflow includes the following steps:\n",
    "- `capture-onnx-graph`: Convert the fine-tuned model to ONNX\n",
    "- `generate-adapter`: Extract the adapters from the ONNX model as model inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run this cell to see the available options to finetune, capture-onnx-graph and generate-adapter commands\n",
    "!olive capture-onnx-graph --help\n",
    "!olive generate-adapter --help"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We convert vae encoder, vae decoder, text encoder, text encoder 2 and unet with LoRAs to ONNX model first by `capture-onnx-graph` command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert vae encoder\n",
    "!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script vae_encoder.py -o onnx_model/vae_encoder\n",
    "# convert vae decoder\n",
    "!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script vae_decoder.py -o onnx_model/vae_decoder\n",
    "# convert text encoder\n",
    "!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script text_encoder.py -o onnx_model/text_encoder\n",
    "# convert text encoder 2\n",
    "!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script text_encoder2.py -o onnx_model/text_encoder_2\n",
    "# convert unet model with wolf plushie lora\n",
    "!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script unet_wolf_plushie.py -o onnx_model/unet_wolf_plushie\n",
    "# convert unet model with pen sketch lora\n",
    "!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script unet_pen_sketch.py -o onnx_model/unet_pen_sketch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try pen sketch LoRA first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnxruntime as ort\n",
    "import torch\n",
    "from pathlib import Path\n",
    "from optimum.onnxruntime import ORTStableDiffusionXLPipeline\n",
    "from diffusers import DiffusionPipeline, OnnxRuntimeModel\n",
    "\n",
    "ort.set_default_logger_severity(3)\n",
    "\n",
    "sess_options = ort.SessionOptions()\n",
    "sess_options.enable_mem_pattern = False\n",
    "\n",
    "model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
    "onnx_model_path = Path(\"onnx_model\")\n",
    "provider = \"CUDAExecutionProvider\"\n",
    "pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)\n",
    "\n",
    "vae_encoder_session = OnnxRuntimeModel.load_model(onnx_model_path / \"vae_encoder\" / \"model.onnx\", provider=provider)\n",
    "vae_decoder_session = OnnxRuntimeModel.load_model(onnx_model_path / \"vae_decoder\" / \"model.onnx\", provider=provider)\n",
    "text_encoder_session = OnnxRuntimeModel.load_model(onnx_model_path / \"text_encoder\" / \"model.onnx\", provider=provider)\n",
    "text_encoder_2_session = OnnxRuntimeModel.load_model(onnx_model_path / \"text_encoder_2\" / \"model\" / \"model.onnx\", provider=provider)\n",
    "\n",
    "pen_sketch_unet_session = OnnxRuntimeModel.load_model(onnx_model_path / \"unet_pen_sketch\" / \"model\" / \"model.onnx\", provider=provider)\n",
    "\n",
    "onnx_pipeline = ORTStableDiffusionXLPipeline(\n",
    "    vae_encoder_session=vae_encoder_session,\n",
    "    vae_decoder_session=vae_decoder_session,\n",
    "    text_encoder_session=text_encoder_session,\n",
    "    unet_session=pen_sketch_unet_session,\n",
    "    text_encoder_2_session=text_encoder_2_session,\n",
    "    tokenizer=pipeline.tokenizer,\n",
    "    tokenizer_2=pipeline.tokenizer_2,\n",
    "    scheduler=pipeline.scheduler,\n",
    "    feature_extractor=pipeline.feature_extractor,\n",
    "    config=dict(pipeline.config),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"A woman is dancing [v30]\"\n",
    "batch_size = 1\n",
    "result = onnx_pipeline(\n",
    "    [prompt] * batch_size,\n",
    "    num_inference_steps=20,\n",
    "    height=512,\n",
    "    width=512,\n",
    ")\n",
    "\n",
    "for image_index in range(batch_size):\n",
    "    output_path = f\"result_pen_{image_index}.png\"\n",
    "    result.images[image_index].save(output_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del pen_sketch_unet_session\n",
    "del onnx_pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then you can see the image, like this:   \n",
    "![pen sketch](image/result_pen.png)\n",
    "\n",
    "Let's try wolf plushie LoRA:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Update onnx pipeline with wolf plushie unet\n",
    "wolf_plushie_unet_session = OnnxRuntimeModel.load_model(onnx_model_path / \"unet_wolf_plushie\" / \"model\" / \"model.onnx\", provider=provider)\n",
    "\n",
    "onnx_pipeline = ORTStableDiffusionXLPipeline(\n",
    "    vae_encoder_session=vae_encoder_session,\n",
    "    vae_decoder_session=vae_decoder_session,\n",
    "    text_encoder_session=text_encoder_session,\n",
    "    unet_session=wolf_plushie_unet_session,\n",
    "    text_encoder_2_session=text_encoder_2_session,\n",
    "    tokenizer=pipeline.tokenizer,\n",
    "    tokenizer_2=pipeline.tokenizer_2,\n",
    "    scheduler=pipeline.scheduler,\n",
    "    feature_extractor=pipeline.feature_extractor,\n",
    "    config=dict(pipeline.config),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"A running wolf\"\n",
    "batch_size = 1\n",
    "result = onnx_pipeline(\n",
    "    [prompt] * batch_size,\n",
    "    num_inference_steps=20,\n",
    "    height=512,\n",
    "    width=512,\n",
    ")\n",
    "\n",
    "for image_index in range(batch_size):\n",
    "    output_path = f\"result_wolf_{image_index}.png\"\n",
    "    result.images[image_index].save(output_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del wolf_plushie_unet_session\n",
    "del onnx_pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then you can see the image, like this:   \n",
    "![wolf plushie](image/result_wolf.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, extract the adapters from the ONNX model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract adapters from pen sketch unet model\n",
    "!olive generate-adapter -m onnx_model/unet_pen_sketch -o adapters/pen_sketch\n",
    "# Extract adapters from wolf plushie unet model\n",
    "!olive generate-adapter -m onnx_model/unet_wolf_plushie -o adapters/wolf_plushie"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploy Model with Multiple Adapters\n",
    "\n",
    "We can now deploy the same model with multiple adapters for different tasks by loading the adapter weights independently of the model and providing the relevant weights as input at inference time."
   ]
  },
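  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea concrete, here is a minimal NumPy sketch of a single LoRA layer in which the low-rank matrices are passed in at call time rather than baked into the weights. The names `W`, `toy_adapters` and `lora_linear`, the shapes and the random values are all illustrative assumptions; this is not the UNet itself or Olive's implementation. It only shows that one shared base weight can serve several adapters, and that an empty (rank 0) adapter leaves the base output unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "d_in, d_out, rank = 8, 4, 2\n",
    "\n",
    "# shared base weight, playing the role of the frozen weights stored in the base model (toy values)\n",
    "W = rng.standard_normal((d_in, d_out))\n",
    "\n",
    "# two hypothetical adapters, each a pair of low-rank matrices (A, B)\n",
    "toy_adapters = {\n",
    "    \"pen_sketch\": (rng.standard_normal((d_in, rank)), rng.standard_normal((rank, d_out))),\n",
    "    \"wolf_plushie\": (rng.standard_normal((d_in, rank)), rng.standard_normal((rank, d_out))),\n",
    "}\n",
    "\n",
    "\n",
    "def lora_linear(x, lora_a, lora_b):\n",
    "    # base projection plus the low-rank update supplied as extra inputs\n",
    "    return x @ W + (x @ lora_a) @ lora_b\n",
    "\n",
    "\n",
    "x = rng.standard_normal((1, d_in))\n",
    "y_pen = lora_linear(x, *toy_adapters[\"pen_sketch\"])\n",
    "y_wolf = lora_linear(x, *toy_adapters[\"wolf_plushie\"])\n",
    "\n",
    "# rank-0 matrices contribute a zero update, so the output equals the base projection\n",
    "y_base = lora_linear(x, np.zeros((d_in, 0)), np.zeros((0, d_out)))\n",
    "\n",
    "assert np.allclose(y_base, x @ W)\n",
    "assert not np.allclose(y_pen, y_wolf)"
   ]
  },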
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "file_list = [\"model.onnx\", \"model.onnx.data\"]\n",
    "old_model_path = Path(\"adapters/pen_sketch/model\")\n",
    "base_model_path = Path(\"model\")\n",
    "\n",
    "base_model_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# copy base onnx model to new folder\n",
    "for file in file_list:\n",
    "    shutil.copy2(old_model_path / file, base_model_path / file)\n",
    "\n",
    "base_model_path = base_model_path / \"model.onnx\"\n",
    "\n",
    "adapters = {\n",
    "    \"pen_sketch\":  \"adapters/pen_sketch/model/adapter_weights.onnx_adapter\",\n",
    "    \"wolf_plushie\": \"adapters/wolf_plushie/model/adapter_weights.onnx_adapter\"\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we implement a `UNetSessionWrapper` to inject adapter weights as inputs to the UNet model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "utils_path = Path().resolve().parent.parent.parent / \"olive\" / \"common\"\n",
    "sys.path.append(str(utils_path))\n",
    "\n",
    "from utils import load_weights\n",
    "\n",
    "class UNetSessionWrapper:\n",
    "    def __init__(self, unet_session):\n",
    "        self.unet_session = unet_session\n",
    "        self.adapters = {}\n",
    "        self.active_adapter = None\n",
    "\n",
    "    def load_adapter(self, adapter_name, adapter_path):\n",
    "        self.adapters[adapter_name] = load_weights(adapter_path)\n",
    "\n",
    "    def set_adapter(self, adapter_name):\n",
    "        assert adapter_name in self.adapters, f\"Adapter {adapter_name} not found\"\n",
    "        self.active_adapter = adapter_name\n",
    "\n",
    "    def unset_adapter(self):\n",
    "        self.active_adapter = None\n",
    "\n",
    "    def run(self, output_names, input_feed):\n",
    "        # running when adapter is not set is equivalent to running the base model\n",
    "        inputs = {**input_feed, **self.adapters.get(self.active_adapter, {})}\n",
    "        return self.unet_session.run(output_names, inputs)\n",
    "\n",
    "    def __getattr__(self, name):\n",
    "        return getattr(self.unet_session, name)\n",
    "\n",
    "\n",
    "# load unet session from base model\n",
    "unet_session = OnnxRuntimeModel.load_model(base_model_path, provider=provider)\n",
    "unet_session_wrapped = UNetSessionWrapper(unet_session)\n",
    "\n",
    "unet_session_wrapped.load_adapter(\"pen_sketch\", adapters[\"pen_sketch\"])\n",
    "unet_session_wrapped.load_adapter(\"wolf_plushie\", adapters[\"wolf_plushie\"])\n",
    "\n",
    "onnx_pipeline = ORTStableDiffusionXLPipeline(\n",
    "    vae_encoder_session=vae_encoder_session,\n",
    "    vae_decoder_session=vae_decoder_session,\n",
    "    text_encoder_session=text_encoder_session,\n",
    "    unet_session=unet_session_wrapped,\n",
    "    text_encoder_2_session=text_encoder_2_session,\n",
    "    tokenizer=pipeline.tokenizer,\n",
    "    tokenizer_2=pipeline.tokenizer_2,\n",
    "    scheduler=pipeline.scheduler,\n",
    "    feature_extractor=pipeline.feature_extractor,\n",
    "    config=dict(pipeline.config),\n",
    ")"
   ]
  },
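  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check, the names of the loaded adapter weights can be compared with the input names of the base UNet session. This sketch assumes that `load_weights` returns a dictionary keyed by model input name, which is what `UNetSessionWrapper.run` relies on when it merges the weights into the input feed. `get_inputs()` is the standard ONNX Runtime `InferenceSession` method; the variable names `unet_input_names` and `unknown` are only illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# names of all inputs exposed by the base UNet session\n",
    "unet_input_names = {model_input.name for model_input in unet_session.get_inputs()}\n",
    "\n",
    "for adapter_name, adapter_weights in unet_session_wrapped.adapters.items():\n",
    "    # weight names that match no model input are likely to cause an error at run time\n",
    "    unknown = set(adapter_weights) - unet_input_names\n",
    "    print(f\"{adapter_name}: {len(adapter_weights)} weights, {len(unknown)} not found among model inputs\")"
   ]
  },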
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Generate with Pen Sketch Adapters\n",
    "Let's test pen sketch adapters first"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unet_session_wrapped.set_adapter(\"pen_sketch\")\n",
    "\n",
    "prompt = \"A dancing woman\"\n",
    "batch_size = 1\n",
    "result = onnx_pipeline(\n",
    "    [prompt] * batch_size,\n",
    "    num_inference_steps=20,\n",
    "    height=512,\n",
    "    width=512,\n",
    ")\n",
    "\n",
    "for image_index in range(batch_size):\n",
    "    output_path = f\"result_pen_merge_{image_index}.png\"\n",
    "    result.images[image_index].save(output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then you can see the image from pen sketch adapter weight merged model, like this:   \n",
    "![pen sketch](image/result_pen_merge.png)\n",
    "\n",
    "#### Generate with Wolf Plushie Adapters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unet_session_wrapped.set_adapter(\"wolf_plushie\")\n",
    "\n",
    "prompt = \"A running wolf\"\n",
    "batch_size = 1\n",
    "result = onnx_pipeline(\n",
    "    [prompt] * batch_size,\n",
    "    num_inference_steps=20,\n",
    "    height=512,\n",
    "    width=512,\n",
    ")\n",
    "\n",
    "for image_index in range(batch_size):\n",
    "    output_path = f\"result_wolf_merge_{image_index}.png\"\n",
    "    result.images[image_index].save(output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then you can see the image from wolf plushie adapter weight merged model, like this:   \n",
    "![wolf plushie](image/result_wolf_merge.png)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "olive",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
